{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "QEydh1KI7kPU"
   },
   "source": [
    "# Training the BERTSUM model \n",
    "The code for training a BERTSUM model is open-sourced by the researchers of BERTSUM and it is available at https://github.com/nlpyang/BertSum. In this section, let us explore this and learn how to train the BERTSUM model. We will train the BERTSUM model on the CNN/DailyMail news dataset. \n",
    "\n",
    "#### Makse sure to run the notebook in GPU \n",
    "\n",
    "First, let us install the necessary libraries: \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true,
    "id": "CijQP8eX7kPg"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "!pip install pytorch_pretrained_bert\n",
    "!pip install torch==1.1.0 pytorch_transformers tensorboardX multiprocess pyrouge\n",
    "!pip install googleDriveFileDownloader"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "6zefDK647kPh"
   },
   "source": [
    "\n",
    "If you are working with Google's colab, switch to the content directory with the following code: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1RU3R-tN7kPh",
    "outputId": "a2b81637-54dc-4b41-bc43-8726adc3e6cb"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/content\n"
     ]
    }
   ],
   "source": [
    "cd /content/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "znnhL1XW7kPj"
   },
   "source": [
    "\n",
    "Clone the BERTSUM repository: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "kOdMb3IM7kPl",
    "outputId": "0a62f0e4-a199-4b8d-f8a3-a1a65ff33e70"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Cloning into 'BertSum'...\n",
      "remote: Enumerating objects: 301, done.\u001b[K\n",
      "remote: Total 301 (delta 0), reused 0 (delta 0), pack-reused 301\u001b[K\n",
      "Receiving objects: 100% (301/301), 15.03 MiB | 21.35 MiB/s, done.\n",
      "Resolving deltas: 100% (174/174), done.\n"
     ]
    }
   ],
   "source": [
    "!git clone https://github.com/nlpyang/BertSum.git"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "0HGK0q2y7kPo"
   },
   "source": [
    "\n",
    "Now switch to the bert_data directory: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "5qTMNpxt7kPq",
    "outputId": "f9d575ce-13bf-4516-e06b-0a51e4a86286"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/content/BertSum/bert_data\n"
     ]
    }
   ],
   "source": [
    "cd /content/BertSum/bert_data/"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Rrw-CM-77kPs"
   },
   "source": [
    "\n",
    "The researchers have also made available the preprocessed CNN/DailyMail news dataset. So, first, let us download them: \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "oYUWbSd47kPv",
    "outputId": "a8f179f4-7656-405c-fb60-e27bad564aad"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Download is starting\n",
      "bertsum_data.zip\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "0"
      ]
     },
     "execution_count": 5,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from googleDriveFileDownloader import googleDriveFileDownloader\n",
    "gdrive = googleDriveFileDownloader()\n",
    "gdrive.downloadFile(\"https://drive.google.com/uc?id=1x0d61LP9UAN389YN00z0Pv-7jQgirVg6&export=download\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "fG4aO-NI7kPw"
   },
   "source": [
    "\n",
    "Unzip the downloaded dataset: \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "collapsed": true,
    "id": "jMajhZEt7kPw"
   },
   "outputs": [],
   "source": [
    "%%capture\n",
    "!unzip /content/BertSum/bert_data/bertsum_data.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "B2MtFAUo7kP1"
   },
   "source": [
    "\n",
    "Switch to src directory: "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "NykOSG697kP4",
    "outputId": "a0cc3509-6111-4fdb-8975-2f5a9a05bd23"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/content/BertSum/src\n"
     ]
    }
   ],
   "source": [
    "cd /content/BertSum/src"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "F02kTSMI7kP6"
   },
   "source": [
    "\n",
    "Now train the BERTSUM model. In the following code, the argument -encoder classifier implies that we are training the BERTSUM model with a classifier: \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "4Dmb4zk27kP7",
    "outputId": "92227892-c66d-41f5-8998-b52df507eafe"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[2020-12-30 23:26:42,695 INFO] Device ID 0\n",
      "[2020-12-30 23:26:42,695 INFO] Device cuda\n",
      "[2020-12-30 23:26:42,817 INFO] https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz not found in cache, downloading to /tmp/tmpj9o62s7t\n",
      "100% 407873900/407873900 [00:05<00:00, 76742717.80B/s]\n",
      "[2020-12-30 23:26:48,238 INFO] copying /tmp/tmpj9o62s7t to cache at ../temp/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba\n",
      "[2020-12-30 23:26:54,952 INFO] creating metadata file for ../temp/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba\n",
      "[2020-12-30 23:26:54,953 INFO] removing temp file /tmp/tmpj9o62s7t\n",
      "[2020-12-30 23:26:54,989 INFO] loading archive file https://s3.amazonaws.com/models.huggingface.co/bert/bert-base-uncased.tar.gz from cache at ../temp/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba\n",
      "[2020-12-30 23:26:54,989 INFO] extracting archive file ../temp/9c41111e2de84547a463fd39217199738d1e3deb72d4fec4399e6e241983c6f0.ae3cef932725ca7a30cdcb93fc6e09150a55e2a130ec7af63975a16c153ae2ba to temp dir /tmp/tmpscbcx33s\n"
     ]
    }
   ],
   "source": [
    "!python train.py -mode train -encoder classifier -dropout 0.1 -bert_data_path ../bert_data/cnndm -model_path ../models/bert_classifier -lr 2e-3 -visible_gpus 0 -gpu_ranks 0 -world_size 1 -report_every 50 -save_checkpoint_steps 1000 -batch_size 3000 -decay_method noam -train_steps 50 -accum_count 2 -log_file ../logs/bert_classifier -use_interval true -warmup_steps 10000\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "dJIo-Eja7kP8"
   },
   "source": [
    "That's it. During training, you can see how the ROUGE score changes over a series of epochs. In the next chapter, we will learn how to apply BERT for other languages using multilingual-BERT."
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "name": "6.07. Training the BERTSUM model .ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
